package com.olive.store;


import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.oauth2.common.DefaultOAuth2AccessToken;
import org.springframework.security.oauth2.common.ExpiringOAuth2RefreshToken;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.OAuth2RefreshToken;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.OAuth2Request;
import org.springframework.security.oauth2.provider.token.AuthenticationKeyGenerator;
import org.springframework.security.oauth2.provider.token.TokenStore;
import org.springframework.stereotype.Component;

import com.alibaba.fastjson.JSONObject;
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import com.olive.dto.JwtUserDto;
import com.olive.entity.Role;
import com.olive.entity.User;

import cn.hutool.core.date.DateUnit;
import cn.hutool.core.date.DateUtil;
import cn.hutool.core.util.StrUtil;

@Component
/**
 * {@link TokenStore} that keeps OAuth2 tokens and authentications in Redis as fastjson JSON strings,
 * fronted by a small process-local Caffeine cache.
 *
 * <p>Redis key layout:
 * <ul>
 *   <li>{@code access:<token>} - serialized access token</li>
 *   <li>{@code auth:<token>} - serialized OAuth2Authentication</li>
 *   <li>{@code auth_to_access:<authenticationKey>} - serialized access token</li>
 *   <li>{@code uname_to_access:<clientId>:<mobile>_<username>} - hash: token value to serialized access token</li>
 *   <li>{@code client_id_to_access:<clientId>} - hash: token value to serialized access token</li>
 *   <li>{@code refresh:<refresh>} / {@code refresh_auth:<refresh>} - serialized refresh token / authentication</li>
 *   <li>{@code refresh_to_access:<refresh>} / {@code access_to_refresh:<token>} - token value mappings</li>
 * </ul>
 *
 * <p>The local cache is only invalidated by writes made through this process, so reads can return
 * data up to 60 seconds stale after another process changes Redis.
 */
public class JsonRedisTokenStore implements TokenStore {
	private static final String ACCESS = "access:";
	private static final String AUTH_TO_ACCESS = "auth_to_access:";
	private static final String AUTH = "auth:";
	private static final String REFRESH_AUTH = "refresh_auth:";
	private static final String ACCESS_TO_REFRESH = "access_to_refresh:";
	private static final String REFRESH = "refresh:";
	private static final String REFRESH_TO_ACCESS = "refresh_to_access:";
	private static final String CLIENT_ID_TO_ACCESS = "client_id_to_access:";
	private static final String UNAME_TO_ACCESS = "uname_to_access:";

	/** TTL (30 days) applied when a token has no expiration date. */
	private static final int DEFAULT_TTL_SECONDS = 30 * 24 * 60 * 60;

	/** Process-local cache shared by all instances: entries expire 60 s after write, at most 1000 entries. */
	private final static Cache<Object, Object> CACHE;

	static {
		CACHE = Caffeine.newBuilder()
				.expireAfterWrite(60, TimeUnit.SECONDS)
				.maximumSize(1000).build();
	}


	private final StringRedisTemplate stringRedisTemplate;
	private final AuthenticationKeyGenerator authenticationKeyGenerator = new CustomAuthenticationKeyGenerator();


	public JsonRedisTokenStore(StringRedisTemplate stringRedisTemplate) {
		this.stringRedisTemplate = stringRedisTemplate;
	}


	@Override
	public OAuth2Authentication readAuthentication(OAuth2AccessToken token) {
		return this.readAuthentication(token.getValue());
	}

	/**
	 * Looks up the authentication (authenticated user information) stored for an access token value.
	 *
	 * @param token access token value
	 * @return the authentication, or {@code null} when nothing is stored for the token
	 */
	@Override
	public OAuth2Authentication readAuthentication(String token) {
		String key = AUTH + token;

		return (OAuth2Authentication) loadCache(key, (k) -> {
			String json = stringRedisTemplate.opsForValue().get(key);
			if (StrUtil.isBlank(json)) {
				return null;
			}
			return fullParseJSON(json);
		});
	}

	/**
	 * Rebuilds a complete OAuth2Authentication from its fastjson representation.
	 *
	 * @param json complete OAuth2Authentication JSON string
	 * @return the OAuth2Authentication; client-only when no user authentication was stored
	 */
	private OAuth2Authentication fullParseJSON(String json) {
		JSONObject jsonObject = JSONObject.parseObject(json);

		OAuth2Request oAuth2Request = parseOAuth2Request(jsonObject.getJSONObject("oAuth2Request"));

		// A client-only authentication has no user part, so the property is absent from the JSON.
		JSONObject userAuthenticationObject = jsonObject.getJSONObject("userAuthentication");
		UsernamePasswordAuthenticationToken userAuthentication = userAuthenticationObject == null
				? null
				: parseUserAuthentication(userAuthenticationObject);

		return new OAuth2Authentication(oAuth2Request, userAuthentication);
	}

	/** Rebuilds the user authentication whose principal is the project's JwtUserDto. */
	private UsernamePasswordAuthenticationToken parseUserAuthentication(JSONObject userAuthenticationObject) {
		// JwtUserDto (the principal) is built from three parts: User, Set<Role> and List<String> authority
		// names, so those are read back from the "principal" object. Adapt this to your own UserDetails.
		JSONObject principal = userAuthenticationObject.getJSONObject("principal");
		User userInfo = principal.getObject("user", User.class);
		Set<Role> roleInfo = new HashSet<>(toList(principal, "roleInfo", Role.class));
		List<String> authorityNames = toList(principal, "authorityNames", String.class);
		JwtUserDto jwtUserDto = new JwtUserDto(userInfo, roleInfo, authorityNames);

		String credentials = userAuthenticationObject.getString("credentials");

		// Details are restored as an ordered map; a missing "details" object yields an empty map.
		LinkedHashMap<String, Object> details = new LinkedHashMap<>();
		JSONObject detailsJSONObject = userAuthenticationObject.getJSONObject("details");
		if (detailsJSONObject != null) {
			details.putAll(detailsJSONObject);
		}

		UsernamePasswordAuthenticationToken userAuthentication = new UsernamePasswordAuthenticationToken(jwtUserDto
				, credentials, new ArrayList<>(0));
		userAuthentication.setDetails(details);
		return userAuthentication;
	}

	/** Rebuilds the stored OAuth2Request; missing collections are restored as empty ones. */
	private OAuth2Request parseOAuth2Request(JSONObject storedRequest) {
		Map<String, String> requestParameters = new HashMap<>();
		JSONObject requestParametersJSON = storedRequest.getJSONObject("requestParameters");
		if (requestParametersJSON != null) {
			for (String key : requestParametersJSON.keySet()) {
				requestParameters.put(key, requestParametersJSON.getString(key));
			}
		}

		return new OAuth2Request(requestParameters
				, storedRequest.getString("clientId")
				// Authorities are not restored: this project does not do role/authority checks here.
				, new ArrayList<>(0)
				// Boolean.TRUE.equals avoids an unboxing NPE when "approved" is absent.
				, Boolean.TRUE.equals(storedRequest.getBoolean("approved"))
				, convertSetString(storedRequest, "scope")
				, convertSetString(storedRequest, "resourceIds")
				, storedRequest.getString("redirectUri")
				, convertSetString(storedRequest, "responseTypes")
				, null //extensionProperties
		);
	}

	/** Reads a JSON string array as a Set; an absent array yields an empty set. */
	private static Set<String> convertSetString(JSONObject data, String key) {
		return new HashSet<>(toList(data, key, String.class));
	}

	/** Reads a JSON array as a List of {@code type}; an absent or null array yields an empty list. */
	private static <E> List<E> toList(JSONObject data, String key, Class<E> type) {
		if (data.get(key) == null) {
			return new ArrayList<>(0);
		}
		return data.getJSONArray(key).toJavaList(type);
	}

	/**
	 * Stores the access token, its authentication and the lookup indexes, plus the refresh-token
	 * mappings when the token has a refresh token.
	 *
	 * <p>All keys get a TTL equal to the token's remaining lifetime (30 days when it has none).
	 */
	@Override
	public void storeAccessToken(OAuth2AccessToken token, OAuth2Authentication authentication) {
		String tokenValue = token.getValue();
		String serializedAccessToken = JSONObject.toJSONString(token);
		String serializedAuth = JSONObject.toJSONString(authentication);
		String accessKey = ACCESS + tokenValue;
		String authKey = AUTH + tokenValue;
		String authToAccessKey = AUTH_TO_ACCESS + authenticationKeyGenerator.extractKey(authentication);
		String approvalKey = UNAME_TO_ACCESS + getApprovalKey(authentication);
		String clientIdKey = CLIENT_ID_TO_ACCESS + authentication.getOAuth2Request().getClientId();

		int seconds = secondsUntil(token.getExpiration(), DEFAULT_TTL_SECONDS);

		try {
			stringRedisTemplate.opsForValue().set(accessKey, serializedAccessToken);
			stringRedisTemplate.opsForValue().set(authKey, serializedAuth);
			stringRedisTemplate.opsForValue().set(authToAccessKey, serializedAccessToken);
			// The client-id index is a hash (like the username index) so findTokensByClientId can read it.
			stringRedisTemplate.opsForHash().putIfAbsent(clientIdKey, tokenValue, serializedAccessToken);

			if (!authentication.isClientOnly()) {
				stringRedisTemplate.opsForHash().putIfAbsent(approvalKey, tokenValue, serializedAccessToken);
			}
		} finally {
			// Apply TTLs even if a write failed midway, so whatever was written still expires.
			stringRedisTemplate.expire(accessKey, seconds, TimeUnit.SECONDS);
			stringRedisTemplate.expire(authKey, seconds, TimeUnit.SECONDS);
			stringRedisTemplate.expire(authToAccessKey, seconds, TimeUnit.SECONDS);
			stringRedisTemplate.expire(clientIdKey, seconds, TimeUnit.SECONDS);
			stringRedisTemplate.expire(approvalKey, seconds, TimeUnit.SECONDS);
		}

		OAuth2RefreshToken refreshToken = token.getRefreshToken();
		if (refreshToken != null && refreshToken.getValue() != null) {
			String refreshValue = refreshToken.getValue();
			String refreshToAccessKey = REFRESH_TO_ACCESS + refreshValue;
			String accessToRefreshKey = ACCESS_TO_REFRESH + tokenValue;

			try {
				stringRedisTemplate.opsForValue().set(refreshToAccessKey, tokenValue);
				stringRedisTemplate.opsForValue().set(accessToRefreshKey, refreshValue);
			} finally {
				// Apply TTLs even if a write failed midway.
				refreshTokenProcess(refreshToken, refreshToAccessKey, accessToRefreshKey);
			}

			CACHE.put(refreshToAccessKey, tokenValue);
			CACHE.put(accessToRefreshKey, refreshValue);
		}

		CACHE.put(accessKey, token);
		CACHE.put(authKey, authentication);
		CACHE.put(authToAccessKey, token);
		// Cached token collections for these indexes no longer reflect Redis.
		CACHE.invalidate(approvalKey);
		CACHE.invalidate(clientIdKey);
	}

	/**
	 * Sets the refresh token's remaining lifetime as TTL on the given keys. Keys are left without
	 * a TTL when the refresh token is not an {@link ExpiringOAuth2RefreshToken}.
	 */
	private void refreshTokenProcess(OAuth2RefreshToken refreshToken, String... keys) {
		if (!(refreshToken instanceof ExpiringOAuth2RefreshToken)) {
			return;
		}
		Date expiration = ((ExpiringOAuth2RefreshToken) refreshToken).getExpiration();
		int seconds = secondsUntil(expiration, DEFAULT_TTL_SECONDS);
		for (String key : keys) {
			stringRedisTemplate.expire(key, seconds, TimeUnit.SECONDS);
		}
	}

	/**
	 * Returns the whole seconds from now until {@code expiration}, or {@code fallback} when it is null.
	 * The result is negative for a past expiration; Redis deletes a key given a non-positive timeout,
	 * so an already-expired token is not kept alive.
	 */
	private static int secondsUntil(Date expiration, int fallback) {
		if (expiration == null) {
			return fallback;
		}
		return (int) ((expiration.getTime() - System.currentTimeMillis()) / 1000L);
	}


	/** Builds the username-index suffix {@code <clientId>:<mobile>_<username>} ({@code <clientId>:} when client-only). */
	private String getApprovalKey(OAuth2Authentication authentication) {
		String userName = "";
		if (authentication.getUserAuthentication() != null) {
			JwtUserDto userInfoDetails = (JwtUserDto) authentication.getUserAuthentication().getPrincipal();
			userName = userInfoDetails.getUser().getMobile() + "_" + userInfoDetails.getUsername();
		}

		return getApprovalKey(authentication.getOAuth2Request().getClientId(), userName);
	}

	private String getApprovalKey(String clientId, String userName) {
		return clientId + (userName == null ? "" : ":" + userName);
	}


	/**
	 * Looks up an access token by its value: first the local cache, then Redis.
	 *
	 * @return the token, or {@code null} when neither has it
	 */
	@Override
	public OAuth2AccessToken readAccessToken(String tokenValue) {
		String key = ACCESS + tokenValue;
		return (OAuth2AccessToken) loadCache(key, (k) -> {
			String json = stringRedisTemplate.opsForValue().get(key);
			if (StrUtil.isNotBlank(json)) {
				return JSONObject.parseObject(json, DefaultOAuth2AccessTokenEx.class);
			}
			return null;
		});
	}

	@Override
	public void removeAccessToken(OAuth2AccessToken accessToken) {
		removeAccessToken(accessToken.getValue());
	}

	/**
	 * Removes an access token, its authentication and its index entries from Redis and the local cache.
	 * Refresh-token keys ({@code refresh:}, {@code refresh_auth:}) are not touched here.
	 */
	public void removeAccessToken(String tokenValue) {
		String accessKey = ACCESS + tokenValue;
		String authKey = AUTH + tokenValue;
		String accessToRefreshKey = ACCESS_TO_REFRESH + tokenValue;

		// Read before deleting: the authentication is needed to locate the index keys.
		OAuth2Authentication authentication = readAuthentication(tokenValue);

		List<String> keys = new ArrayList<>(3);
		keys.add(accessKey);
		keys.add(authKey);
		keys.add(accessToRefreshKey);

		stringRedisTemplate.delete(keys);

		if (authentication != null) {
			String key = authenticationKeyGenerator.extractKey(authentication);
			String authToAccessKey = AUTH_TO_ACCESS + key;
			String unameKey = UNAME_TO_ACCESS + getApprovalKey(authentication);
			String clientIdKey = CLIENT_ID_TO_ACCESS + authentication.getOAuth2Request().getClientId();

			stringRedisTemplate.delete(authToAccessKey);
			stringRedisTemplate.opsForHash().delete(unameKey, tokenValue);
			stringRedisTemplate.opsForHash().delete(clientIdKey, tokenValue);
			// NOTE(review): nothing in this class writes ACCESS + <authentication key>; kept from the
			// Spring RedisTokenStore layout — confirm no other writer exists before removing.
			stringRedisTemplate.delete(ACCESS + key);

			CACHE.invalidate(authToAccessKey);
			CACHE.invalidate(ACCESS + key);
			CACHE.invalidate(unameKey);
			CACHE.invalidate(clientIdKey);
		}

		CACHE.invalidateAll(keys);
	}

	/** Stores a refresh token and its authentication, with a TTL when the refresh token expires. */
	@Override
	public void storeRefreshToken(OAuth2RefreshToken refreshToken, OAuth2Authentication authentication) {
		String refreshKey = REFRESH + refreshToken.getValue();
		String refreshAuthKey = REFRESH_AUTH + refreshToken.getValue();
		String serializedRefreshToken = JSONObject.toJSONString(refreshToken);

		stringRedisTemplate.opsForValue().set(refreshKey, serializedRefreshToken);
		stringRedisTemplate.opsForValue().set(refreshAuthKey, JSONObject.toJSONString(authentication));

		refreshTokenProcess(refreshToken, refreshKey, refreshAuthKey);

		CACHE.put(refreshKey, refreshToken);
		CACHE.put(refreshAuthKey, authentication);
	}


	/** Looks up a refresh token by value: local cache first, then Redis (same scheme as readAccessToken). */
	@Override
	public OAuth2RefreshToken readRefreshToken(String tokenValue) {
		String key = REFRESH + tokenValue;
		return (OAuth2RefreshToken) loadCache(key, (k) -> {
			String json = stringRedisTemplate.opsForValue().get(key);
			if (StrUtil.isNotBlank(json)) {
				return JSONObject.parseObject(json, DefaultOAuth2RefreshTokenEx.class);
			}

			return null;
		});
	}

	@Override
	public OAuth2Authentication readAuthenticationForRefreshToken(OAuth2RefreshToken token) {
		return this.readAuthenticationForRefreshToken(token.getValue());
	}

	/** Looks up the authentication stored for a refresh token value, or {@code null}. */
	public OAuth2Authentication readAuthenticationForRefreshToken(String token) {
		String key = REFRESH_AUTH + token;

		return (OAuth2Authentication) loadCache(key, (k) -> {
			String json = stringRedisTemplate.opsForValue().get(key);
			if (StrUtil.isBlank(json)) {
				return null;
			}

			return fullParseJSON(json);
		});
	}

	@Override
	public void removeRefreshToken(OAuth2RefreshToken refreshToken) {
		this.removeRefreshToken(refreshToken.getValue());
	}

	/** Removes a refresh token and its mappings from Redis and the local cache. */
	public void removeRefreshToken(String refreshToken) {
		String refreshKey = REFRESH + refreshToken;
		String refreshAuthKey = REFRESH_AUTH + refreshToken;
		String refresh2AccessKey = REFRESH_TO_ACCESS + refreshToken;
		// NOTE(review): access_to_refresh keys are written per ACCESS token value, so this key built from
		// the refresh value does not match them; the real mapping is removed by removeAccessToken.
		String access2RefreshKey = ACCESS_TO_REFRESH + refreshToken;

		List<String> keys = new ArrayList<>(4);
		keys.add(refreshKey);
		keys.add(refreshAuthKey);
		keys.add(refresh2AccessKey);
		keys.add(access2RefreshKey);

		stringRedisTemplate.delete(keys);

		CACHE.invalidateAll(keys);
	}

	@Override
	public void removeAccessTokenUsingRefreshToken(OAuth2RefreshToken refreshToken) {
		this.removeAccessTokenUsingRefreshToken(refreshToken.getValue());
	}

	/** Removes the refresh-to-access mapping and the access token it points to, if any. */
	private void removeAccessTokenUsingRefreshToken(String refreshToken) {
		String key = REFRESH_TO_ACCESS + refreshToken;

		String accessToken = stringRedisTemplate.opsForValue().get(key);
		stringRedisTemplate.delete(key);

		if (accessToken != null) {
			removeAccessToken(accessToken);
		}

		CACHE.invalidate(key);
	}

	/**
	 * Returns the locally cached value for {@code key}, otherwise loads it and caches non-null results.
	 *
	 * <p>Deliberately not {@code CACHE.get(key, loader)}: some loaders (see getAccessToken) write to
	 * CACHE themselves, which Caffeine does not allow from inside its atomic compute.
	 *
	 * @throws RuntimeException wrapping any failure raised while loading
	 */
	private Object loadCache(String key, Function<Object, ?> loadData) {
		try {
			Object value = CACHE.getIfPresent(key);
			if (value == null) {
				value = loadData.apply(key);
				// Only found values are cached locally; misses are retried against Redis next time.
				if (value != null) {
					CACHE.put(key, value);
				}
			}

			return value;
		} catch (Exception e) {
			throw new RuntimeException("JsonRedisTokenStore.loadCache从缓存中加载数据发生错误", e);
		}
	}

	/**
	 * Returns the access token previously issued for this authentication, or {@code null}.
	 * If the authentication stored for that token is missing or has a different authentication key,
	 * the token is re-stored under the given authentication.
	 */
	@Override
	public OAuth2AccessToken getAccessToken(OAuth2Authentication authentication) {
		String authenticationKey = authenticationKeyGenerator.extractKey(authentication);
		String key = AUTH_TO_ACCESS + authenticationKey;

		return (OAuth2AccessToken) loadCache(key, (k) -> {
			String json = stringRedisTemplate.opsForValue().get(key);
			if (StrUtil.isBlank(json)) {
				return null;
			}

			DefaultOAuth2AccessToken accessToken = JSONObject.parseObject(json, DefaultOAuth2AccessTokenEx.class);

			// Compare bare authentication keys; the prefixed Redis key could never match.
			OAuth2Authentication storedAuthentication = readAuthentication(accessToken.getValue());
			if (storedAuthentication == null
					|| !authenticationKey.equals(authenticationKeyGenerator.extractKey(storedAuthentication))) {
				this.storeAccessToken(accessToken, authentication);
			}

			return accessToken;
		});
	}



	/**
	 * Returns the tokens indexed for a client and user name. {@code userName} must use the same
	 * {@code <mobile>_<username>} format that storeAccessToken indexes under.
	 */
	@Override
	public Collection<OAuth2AccessToken> findTokensByClientIdAndUserName(String clientId, String userName) {
		String approvalKey = UNAME_TO_ACCESS + getApprovalKey(clientId, userName);

		return getOAuth2AccessTokens(approvalKey);
	}

	/** Reads an index hash (token value to serialized token) into an unmodifiable collection. */
	@SuppressWarnings("unchecked")
	private Collection<OAuth2AccessToken> getOAuth2AccessTokens(String approvalKey) {
		return (Collection<OAuth2AccessToken>) loadCache(approvalKey, (k) -> {
			Map<Object, Object> accessTokens = stringRedisTemplate.opsForHash().entries(approvalKey);

			if (accessTokens.isEmpty()) {
				return Collections.emptySet();
			}

			List<OAuth2AccessToken> result = new ArrayList<>(accessTokens.size());
			for (Object json : accessTokens.values()) {
				result.add(JSONObject.parseObject(json.toString(), DefaultOAuth2AccessTokenEx.class));
			}
			return Collections.unmodifiableCollection(result);
		});
	}

	/** Returns the tokens indexed for a client id (hash written by storeAccessToken). */
	@Override
	public Collection<OAuth2AccessToken> findTokensByClientId(String clientId) {
		String key = CLIENT_ID_TO_ACCESS + clientId;

		return getOAuth2AccessTokens(key);
	}
}

